package vfile.djl;

import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.bytedeco.opencv.opencv_core.Mat;

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import util.ConsoleProgressBar;
import util.VideoTool;
import util.WorkId;
import vfile.translate.NPtKTranslator;
import vfile.translate.PtGTranslator;

/**
 * Renders a video from a single source image, driven by the key frames of another video.
 *
 * <p>For every key frame of the driving video, a keypoint-detector model ({@code kpdetector.pt})
 * is run on the frame. A generator model ({@code generator.pt}) then produces an output frame
 * from four inputs: the source image tensor, the current driving keypoints, the source keypoints
 * and the keypoints of the first driving frame. The generated frames are handed to
 * {@code VideoTool.push}.
 *
 * <p>The model directory defaults to {@code F:/1/model}. It can be overridden with the
 * {@code face.model.dir} system property.
 */
public class Face {

    /** Directory containing the TorchScript models; override with {@code -Dface.model.dir=...}. */
    private static final Path MODEL_DIR =
            Paths.get(System.getProperty("face.model.dir", "F:/1/model"));

    /** Side length, in pixels, of the square tensor fed to both models. */
    private static final int INPUT_SIZE = 256;

    // NOTE(review): not read or written anywhere in this class; kept for compatibility.
    Predictor<Image, DetectedObjects> predictor;
    // NOTE(review): not read or written anywhere in this class; kept for compatibility.
    int counter;
    private ImageFactory factory;

    /**
     * Runs the full pipeline: load inputs, run both models on every key frame, write the video.
     *
     * @param args ignored
     * @throws IllegalStateException if no key frames could be extracted from the driving video
     * @throws Exception on image/video I/O or model loading/inference failures
     */
    public static void main(String[] args) throws Exception {
        // Do not remove: the original author notes this line works around a bug. Presumably
        // creating a Mat forces the OpenCV native libraries to load before DJL/PyTorch -- confirm.
        Mat one = new Mat();

        String videoType = "mp4";

        // Must be set before the first DJL call so PyTorch is picked as the default engine.
        System.setProperty("ai.djl.default_engine", "PyTorch");
        Path imageFile = Paths.get("src/main/resources/img/2m.jpg");
        Path videoPath = Paths.get("src/main/resources/video/ds." + videoType);
        Path pushPath = Paths.get("build/output/file2." + videoType);

        Image image = ImageFactory.getInstance().fromFile(imageFile);

        List<Image> drivingVideo = VideoTool.getKeyFrame(videoPath.toString());
        if (drivingVideo == null || drivingVideo.isEmpty()) {
            // The first key frame is required as the reference pose; fail with a clear message
            // instead of an IndexOutOfBoundsException further down.
            throw new IllegalStateException("No key frames extracted from " + videoPath);
        }

        Criteria<NDArray, Map> kpDetector =
                Criteria.builder()
                        .setTypes(NDArray.class, Map.class)
                        .optTranslator(new NPtKTranslator())
                        .optEngine("PyTorch")
                        .optModelPath(MODEL_DIR.resolve("kpdetector.pt"))
                        .build();

        Criteria<List, Image> generatorCr =
                Criteria.builder()
                        .setTypes(List.class, Image.class)
                        .optEngine("PyTorch")
                        .optTranslator(new PtGTranslator())
                        .optModelPath(MODEL_DIR.resolve("generator.pt"))
                        .build();

        List<BufferedImage> resultVideo = new ArrayList<>(drivingVideo.size());

        // Models, predictors and the base manager all hold native resources; close them even
        // when loading or inference throws.
        try (NDManager manager = NDManager.newBaseManager();
                ZooModel<NDArray, Map> kpModel = ModelZoo.loadModel(kpDetector);
                Predictor<NDArray, Map> kPredictor = kpModel.newPredictor();
                ZooModel<List, Image> generatorModel = ModelZoo.loadModel(generatorCr);
                Predictor<List, Image> generator = generatorModel.newPredictor()) {

            // These three values are reused for every frame, so they live in the base manager.
            NDArray img = change(image, manager);
            Map kpSource = kPredictor.predict(img);
            Map kpDrivingInitial = kPredictor.predict(change(drivingVideo.get(0), manager));

            int count = 0;
            // Console progress bar.
            ConsoleProgressBar bar = new ConsoleProgressBar(drivingVideo.size());

            for (Image frame : drivingVideo) {
                bar.draw(count++, 1);
                // Per-frame scope: the frame tensor and the intermediates created by change()
                // are released after each frame instead of accumulating until the end of the run.
                try (NDManager frameManager = manager.newSubManager()) {
                    Map kpDriving = kPredictor.predict(change(frame, frameManager));
                    // NOTE(review): keypoint normalisation (ImgTool.normalize_kp) is not applied;
                    // raw driving keypoints are passed to the generator together with the
                    // initial ones.
                    // Generator input order: source image, driving kp, source kp, initial kp.
                    List<Object> g = new ArrayList<>(4);
                    g.add(img);
                    g.add(kpDriving);
                    g.add(kpSource);
                    g.add(kpDrivingInitial);
                    // Copy the result out as a BufferedImage before the frame scope closes.
                    resultVideo.add((BufferedImage) generator.predict(g).getWrappedImage());
                }
            }
        }

        VideoTool.push(videoPath.toString(), pushPath.toString(), resultVideo, videoType);
    }

    /**
     * Converts an image into the model input tensor of shape (1, 3, 256, 256).
     *
     * <p>The image is resized to 256x256, converted to FLOAT32, scaled from [0, 255] to [0, 1]
     * and reordered from height/width/channel to channel/height/width with a leading batch axis.
     *
     * @param img the image to convert
     * @param manager the manager that owns the returned array and its intermediates
     * @return a FLOAT32 array of shape (1, 3, 256, 256)
     */
    private static NDArray change(Image img, NDManager manager) {
        NDArray array = img.toNDArray(manager);
        array = NDImageUtils.resize(array, INPUT_SIZE, INPUT_SIZE);
        // Convert to float BEFORE dividing so the scaling never goes through integer arithmetic
        // on the image's original dtype (same order DJL's own toTensor uses).
        array = array.toType(DataType.FLOAT32, false).div(255f);
        array = array.transpose(2, 0, 1);
        // broadcast both adds the batch axis and fails fast if the channel count is unexpected.
        return array.broadcast(new Shape(1, 3, INPUT_SIZE, INPUT_SIZE));
    }

    /**
     * Writes an ARGB copy of {@code img} as a PNG to {@code build/output/<WorkId.sortUID()>.png}.
     *
     * <p>Despite the name, no bounding boxes are drawn here; the image is saved as-is.
     *
     * @param img the image to save
     * @throws IOException if the output directory or file cannot be created or written
     */
    private static void saveBoundingBoxImage(Image img) throws IOException {
        Path outputDir = Paths.get("build/output");
        Files.createDirectories(outputDir);

        // Make an image copy with an alpha channel because the original image was jpg.
        Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
        Path imagePath = outputDir.resolve(WorkId.sortUID() + ".png");
        // OpenJDK can't save jpg with an alpha channel, hence png. The stream is closed (and
        // flushed) even if encoding fails.
        try (OutputStream out = Files.newOutputStream(imagePath)) {
            newImage.save(out, "png");
        }
    }

    /**
     * Creates an instance bound to the default {@link ImageFactory}.
     *
     * <p>The declared exceptions are not thrown by this body; they are kept for source
     * compatibility with existing callers.
     */
    public Face() throws IOException, ai.djl.ModelException {
        factory = ImageFactory.getInstance();
    }
}
